Optimising a TensorFlow SavedModel for Serving

This notebook shows how to optimise an exported TensorFlow SavedModel by shrinking its size (for a smaller memory and disk footprint) and improving its prediction latency. This can be accomplished by applying the following:

  • Freezing: That is, converting the variables stored in a checkpoint file of the SavedModel into constants stored directly in the model graph.
  • Pruning: That is, stripping nodes that are not on the prediction path of the graph, merging duplicate nodes, and removing other ops such as summary and identity ops.
  • Quantisation: That is, converting any large float Const op into an eight-bit equivalent, followed by a float conversion op so that the result is usable by subsequent nodes.
  • Other refinements: These include constant folding, batch_norm folding, fusing convolutions, etc.
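To make the quantisation bullet concrete, here is a sketch of linear eight-bit quantisation that uses a single min/max range per tensor and then converts back to float. It is a simplified pure-Python illustration with hypothetical function names, not the Graph Transform Tool's implementation. Storing one byte per value instead of a four-byte float is where the size saving comes from.

```python
def quantize_eight_bit(values):
    """Map floats to integer codes in [0, 255] using the tensor's min/max."""
    lo, hi = min(values), max(values)
    # Guard against a constant tensor, whose range would be zero.
    scale = (hi - lo) / 255.0 if hi > lo else 1.0
    codes = [int(round((v - lo) / scale)) for v in values]
    return codes, lo, scale


def dequantize(codes, lo, scale):
    """Convert codes back to approximate floats for downstream float ops."""
    return [lo + c * scale for c in codes]


weights = [-0.5, -0.1, 0.0, 0.25, 0.5]
codes, lo, scale = quantize_eight_bit(weights)
restored = dequantize(codes, lo, scale)

# Rounding to the nearest code bounds the error by about half a step.
max_error = max(abs(w - r) for w, r in zip(weights, restored))
print(codes)
print(max_error)
```

The price of the smaller representation is this bounded rounding error, which grows with the range of values in the tensor.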

The optimisation operations we apply in this example come from the TensorFlow Graph Transform Tool, a C++ command-line tool. We call its C++ libraries through the Python API (TransformGraph).

The Graph Transform Tool is designed to work on models saved as GraphDef files, usually in a binary protobuf format. However, a model exported after training an Estimator is in SavedModel format (a saved_model.pb file plus a variables folder containing the variables.data-* and variables.index files).
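The two formats can be told apart on disk, at least roughly. A SavedModel is a directory containing saved_model.pb (usually with a variables/ folder), whereas a GraphDef is typically a single .pb file. The helper detect_model_format is hypothetical, and the check is a heuristic based only on file layout, not on the protobuf contents.

```python
import os


def detect_model_format(path):
    """Return 'saved_model', 'graph_def' or 'unknown' for a model path.

    Heuristic only: inspects the file layout, not the protobuf contents.
    """
    if os.path.isdir(path):
        # A SavedModel directory holds saved_model.pb at its root.
        if os.path.isfile(os.path.join(path, "saved_model.pb")):
            return "saved_model"
        return "unknown"
    # A frozen/optimised GraphDef is usually a single binary .pb file.
    if os.path.isfile(path) and path.endswith(".pb"):
        return "graph_def"
    return "unknown"
```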

We need to optimise the model while keeping it in the SavedModel format. The optimisation steps are therefore:

  1. Freeze the SavedModel: SavedModel -> GraphDef
  2. Optimise the frozen model: GraphDef -> GraphDef
  3. Convert the optimised frozen model back to a SavedModel: GraphDef -> SavedModel

In [1]:
import os
import numpy as np
from datetime import datetime

import tensorflow as tf

print "TensorFlow : {}".format(tf.__version__)


TensorFlow : 1.10.0

1. Train and Export a TensorFlow DNNClassifier

1.1 Import Data


In [2]:
mnist = tf.contrib.learn.datasets.load_dataset("mnist")
train_data = mnist.train.images
train_labels = np.asarray(mnist.train.labels, dtype=np.int32)
eval_data = mnist.test.images
eval_labels = np.asarray(mnist.test.labels, dtype=np.int32)
NUM_CLASSES = 10


WARNING:tensorflow:From <ipython-input-2-d053803e8488>:1: load_dataset (from tensorflow.contrib.learn.python.learn.datasets) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data.
WARNING:tensorflow:From /Users/khalidsalama/Technology/python-venvs/py27-venv/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/__init__.py:80: load_mnist (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From /Users/khalidsalama/Technology/python-venvs/py27-venv/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:300: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From /Users/khalidsalama/Technology/python-venvs/py27-venv/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please write your own downloading logic.
WARNING:tensorflow:From /Users/khalidsalama/Technology/python-venvs/py27-venv/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST-data/train-images-idx3-ubyte.gz
WARNING:tensorflow:From /Users/khalidsalama/Technology/python-venvs/py27-venv/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST-data/train-labels-idx1-ubyte.gz
Extracting MNIST-data/t10k-images-idx3-ubyte.gz
Extracting MNIST-data/t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From /Users/khalidsalama/Technology/python-venvs/py27-venv/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:290: __init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.

In [3]:
print "Train data shape: {}".format(train_data.shape)
print "Eval data shape: {}".format(eval_data.shape)


Train data shape: (55000, 784)
Eval data shape: (10000, 784)

1.2 Estimator

1.2.1 Model Function


In [4]:
def model_fn(features, labels, mode, params):
    
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

    # convolution layers
    def _cnn_layers(conv_inputs):

        for i in range(params.num_conv_layers):

            filters = params.init_filters * (2**i)
            
            conv = tf.keras.layers.Conv2D(kernel_size=3, filters=filters, strides=1, padding='SAME')(conv_inputs)
            pool = tf.keras.layers.MaxPool2D(pool_size=2, strides=2, padding='SAME')(conv)
            batch_norm = tf.keras.layers.BatchNormalization()(pool, training=is_training)  
            conv_inputs = batch_norm

        outputs = conv_inputs
        return outputs
    
    # fully-connected layers
    def _fully_connected_layers(dense_inputs):
        
        for i in range(len(params.hidden_units)):
            dense = tf.keras.layers.Dense(params.hidden_units[i], activation='relu')(dense_inputs)
            dense_dropout = tf.keras.layers.Dropout(params.dropout)(dense, training=is_training)
            dense_inputs = dense_dropout
            
        outputs = dense_inputs
        return outputs

    # model body
    def _inference(features, mode, params):
        
        input_layer = tf.keras.layers.Reshape([28, 28, 1])(features["input_image"])
        convolutions = _cnn_layers(input_layer)
        flatten = tf.keras.layers.Flatten()(convolutions)
        fully_connected = _fully_connected_layers(flatten)
        
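        # deliberately unused layer: it is not connected to the logits,
        # so it should disappear from the optimised serving graph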
        unused_layers = tf.keras.layers.Dense(units=100, name='unused', activation=None)(flatten)
        
        logits = tf.keras.layers.Dense(units=NUM_CLASSES, name='logits', activation=None)(fully_connected)
        return logits
    
    # model head
    head = tf.contrib.estimator.multi_class_head(n_classes=NUM_CLASSES)
    
    return head.create_estimator_spec(
            features=features,
            mode=mode,
            logits=_inference(features, mode, params),
            labels=labels,
            optimizer=tf.train.AdamOptimizer(params.learning_rate)
    )

1.2.2 Create Custom Estimator


In [5]:
def create_estimator(params, run_config):

    # evaluation metric_fn
    def _metric_fn(labels, predictions):

        metrics = {}
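        pred_class = predictions['class_ids']
        # note: mean_per_class_accuracy averages the per-class accuracies
        # (a macro average), despite the 'micro_accuracy' key name
        metrics['micro_accuracy'] = tf.metrics.mean_per_class_accuracy(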
            labels=labels, predictions=pred_class, num_classes=NUM_CLASSES
        )

        return metrics

    mnist_classifier = tf.estimator.Estimator(
        model_fn=model_fn, params=params, config=run_config)

    mnist_classifier = tf.contrib.estimator.add_metrics(
        estimator=mnist_classifier, metric_fn=_metric_fn)
    
    return mnist_classifier

1.3 Train and Evaluate

1.3.1 Experiment Function


In [6]:
def run_experiment(hparams, run_config):
    
    train_spec = tf.estimator.TrainSpec(
        input_fn = tf.estimator.inputs.numpy_input_fn(
            x={"input_image": train_data},
            y=train_labels,
            batch_size=hparams.batch_size,
            num_epochs=None,
            shuffle=True),
        max_steps=hparams.max_traning_steps
    )

    eval_spec = tf.estimator.EvalSpec(
        input_fn = tf.estimator.inputs.numpy_input_fn(
            x={"input_image": eval_data},
            y=eval_labels,
            batch_size=hparams.batch_size,
            num_epochs=1,
            shuffle=False),
        steps=None,
        throttle_secs=hparams.eval_throttle_secs
    )

    tf.logging.set_verbosity(tf.logging.INFO)

    time_start = datetime.utcnow() 
    print("Experiment started at {}".format(time_start.strftime("%H:%M:%S")))
    print(".......................................") 

    estimator = create_estimator(hparams, run_config)

    tf.estimator.train_and_evaluate(
        estimator=estimator,
        train_spec=train_spec, 
        eval_spec=eval_spec
    )

    time_end = datetime.utcnow() 
    print(".......................................")
    print("Experiment finished at {}".format(time_end.strftime("%H:%M:%S")))
    print("")
    time_elapsed = time_end - time_start
    print("Experiment elapsed time: {} seconds".format(time_elapsed.total_seconds()))
    
    return estimator

1.3.2 Experiment Parameters


In [9]:
MODELS_LOCATION = 'models/mnist'
MODEL_NAME = 'cnn_classifier'
model_dir = os.path.join(MODELS_LOCATION, MODEL_NAME)

print model_dir

hparams  = tf.contrib.training.HParams(
    batch_size=100,
    hidden_units=[512, 512],
    num_conv_layers=3, 
    init_filters=64,
    dropout=0.2,
    max_traning_steps=50,
    eval_throttle_secs=10,
    learning_rate=1e-3
)

run_config = tf.estimator.RunConfig(
    tf_random_seed=19830610,
    save_checkpoints_steps=1000,
    keep_checkpoint_max=3,
    model_dir=model_dir
)


models/mnist/cnn_classifier

TensorFlow Graph

1.3.3 Run Experiment


In [10]:
if tf.gfile.Exists(model_dir):
    print("Removing previous artifacts...")
    tf.gfile.DeleteRecursively(model_dir)
    
estimator = run_experiment(hparams, run_config)


Removing previous artifacts...
Experiment started at 20:53:25
.......................................
INFO:tensorflow:Using config: {'_save_checkpoints_secs': None, '_global_id_in_cluster': 0, '_session_config': None, '_keep_checkpoint_max': 3, '_tf_random_seed': 19830610, '_task_type': 'worker', '_train_distribute': None, '_is_chief': True, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x117098810>, '_model_dir': 'models/mnist/cnn_classifier', '_num_worker_replicas': 1, '_task_id': 0, '_log_step_count_steps': 100, '_master': '', '_save_checkpoints_steps': 1000, '_keep_checkpoint_every_n_hours': 10000, '_evaluation_master': '', '_service': None, '_device_fn': None, '_save_summary_steps': 100, '_num_ps_replicas': 0}
INFO:tensorflow:Using config: {'_save_checkpoints_secs': None, '_session_config': None, '_keep_checkpoint_max': 3, '_task_type': 'worker', '_global_id_in_cluster': 0, '_is_chief': True, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x11650d750>, '_evaluation_master': '', '_save_checkpoints_steps': 1000, '_keep_checkpoint_every_n_hours': 10000, '_service': None, '_num_ps_replicas': 0, '_tf_random_seed': 19830610, '_master': '', '_device_fn': None, '_num_worker_replicas': 1, '_task_id': 0, '_log_step_count_steps': 100, '_model_dir': 'models/mnist/cnn_classifier', '_train_distribute': None, '_save_summary_steps': 100}
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps 1000 or save_checkpoints_secs None.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 0 into models/mnist/cnn_classifier/model.ckpt.
INFO:tensorflow:loss = 3.1126516, step = 1
INFO:tensorflow:Saving checkpoints for 50 into models/mnist/cnn_classifier/model.ckpt.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2018-10-07-20:53:52
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from models/mnist/cnn_classifier/model.ckpt-50
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Finished evaluation at 2018-10-07-20:54:05
INFO:tensorflow:Saving dict for global step 50: accuracy = 0.7091, average_loss = 1.8880125, global_step = 50, loss = 1.8880123, micro_accuracy = 0.71909255
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 50: models/mnist/cnn_classifier/model.ckpt-50
INFO:tensorflow:Loss for final step: 0.3911407.
.......................................
Experiment finished at 20:54:05

Experiment elapsed time: 40.084428 seconds

1.4 Export the model


In [11]:
def make_serving_input_receiver_fn():
    inputs = {'input_image': tf.placeholder(shape=[None,784], dtype=tf.float32, name='input_image')}
    return tf.estimator.export.build_raw_serving_input_receiver_fn(inputs)

export_dir = os.path.join(model_dir, 'export')

if tf.gfile.Exists(export_dir):
    tf.gfile.DeleteRecursively(export_dir)
        
estimator.export_savedmodel(
    export_dir_base=export_dir,
    serving_input_receiver_fn=make_serving_input_receiver_fn()
)


INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Signatures INCLUDED in export for Eval: None
INFO:tensorflow:Signatures INCLUDED in export for Classify: None
INFO:tensorflow:Signatures INCLUDED in export for Regress: None
INFO:tensorflow:Signatures INCLUDED in export for Predict: ['predict']
INFO:tensorflow:Signatures INCLUDED in export for Train: None
INFO:tensorflow:Signatures EXCLUDED from export because they cannot be be served via TensorFlow Serving APIs:
INFO:tensorflow:'serving_default' : Classification input must be a single string Tensor; got {'input_image': <tf.Tensor 'input_image:0' shape=(?, 784) dtype=float32>}
INFO:tensorflow:'classification' : Classification input must be a single string Tensor; got {'input_image': <tf.Tensor 'input_image:0' shape=(?, 784) dtype=float32>}
WARNING:tensorflow:Export includes no default signature!
INFO:tensorflow:Restoring parameters from models/mnist/cnn_classifier/model.ckpt-50
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:No assets to write.
INFO:tensorflow:SavedModel written to: models/mnist/cnn_classifier/export/temp-1538945645/saved_model.pb
Out[11]:
'models/mnist/cnn_classifier/export/1538945645'

2. Inspect the Exported SavedModel


In [12]:
%%bash

saved_models_base=models/mnist/cnn_classifier/export/
saved_model_dir=${saved_models_base}$(ls ${saved_models_base} | tail -n 1)
echo ${saved_model_dir}
ls ${saved_model_dir}
saved_model_cli show --dir=${saved_model_dir} --all


models/mnist/cnn_classifier/export/1538945645
saved_model.pb
variables

MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['predict']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['input_image'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 784)
        name: input_image:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['class_ids'] tensor_info:
        dtype: DT_INT64
        shape: (-1, 1)
        name: head/predictions/ExpandDims:0
    outputs['classes'] tensor_info:
        dtype: DT_STRING
        shape: (-1, 1)
        name: head/predictions/str_classes:0
    outputs['logits'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 10)
        name: logits/BiasAdd:0
    outputs['probabilities'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 10)
        name: head/predictions/probabilities:0
  Method name is: tensorflow/serving/predict

Prediction with SavedModel


In [13]:
def inference_test(saved_model_dir, signature="predict", input_name='input_image', batch=300, repeat=100):

    tf.logging.set_verbosity(tf.logging.ERROR)
    
    time_start = datetime.utcnow() 
    
    predictor = tf.contrib.predictor.from_saved_model(
        export_dir = saved_model_dir,
        signature_def_key=signature
    )
    time_end = datetime.utcnow() 
        
    time_elapsed = time_end - time_start
   
    print ""
    print("Model loading time: {} seconds".format(time_elapsed.total_seconds()))
    print ""
    
    time_start = datetime.utcnow() 
    output = None
    for i in range(repeat):
        output = predictor(
            {
                input_name: eval_data[:batch]
            }
        )
    
    time_end = datetime.utcnow() 

    time_elapsed_sec = (time_end - time_start).total_seconds()
    
    print "Inference elapsed time: {} seconds".format(time_elapsed_sec)
    print ""
    
    print "Prediction produced for {} instances batch, repeated {} times".format(len(output['class_ids']), repeat)
    print "Average latency per batch: {} seconds".format(time_elapsed_sec/repeat)
    print ""
    
    print "Prediction output for the first instance:"
    for key in output.keys():
        print "{}: {}".format(key,output[key][0])
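inference_test times the whole loop, so any one-off cost of the first predictor call is folded into the average. A variation could warm up first and report the median as well. The helper measure_latency is hypothetical. Its fn argument would wrap a predictor call, for example lambda: predictor({'input_image': eval_data[:300]}).

```python
import timeit


def measure_latency(fn, repeat=100, warmup=1):
    """Time repeated calls to fn after some warm-up calls.

    Returns (mean, median) latency in seconds. Warm-up calls are run first
    and excluded, since a freshly loaded model's first call may include
    one-off setup cost.
    """
    for _ in range(warmup):
        fn()
    timings = []
    for _ in range(repeat):
        start = timeit.default_timer()
        fn()
        timings.append(timeit.default_timer() - start)
    timings.sort()
    mean = sum(timings) / len(timings)
    mid = len(timings) // 2
    if len(timings) % 2:
        median = timings[mid]
    else:
        median = (timings[mid - 1] + timings[mid]) / 2.0
    return mean, median
```

The median is less sensitive than the mean to occasional slow calls, such as those caused by other processes on the same machine.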

3. Test Prediction with SavedModel


In [14]:
saved_model_dir = os.path.join(export_dir, os.listdir(export_dir)[-1]) 
print(saved_model_dir)
inference_test(saved_model_dir)


models/mnist/cnn_classifier/export/1538945645

Model loading time: 0.131457 seconds

Inference elapsed time: 37.999269 seconds

Prediction produced for 300 instances batch, repeated 100 times
Average latency per batch: 0.37999269 seconds

Prediction output for the last instance:
probabilities: [0.08772983 0.06812517 0.1355336  0.1237169  0.06289934 0.08402904
 0.05941513 0.1848294  0.09718301 0.09653859]
class_ids: [7]
classes: ['7']
logits: [-0.05797689 -0.31089228  0.3769806   0.28575695 -0.3907033  -0.10107648
 -0.44768998  0.6871943   0.044357    0.03770389]

Describe GraphDef


In [15]:
def describe_graph(graph_def, show_nodes=False):
    
    print 'Input Feature Nodes: {}'.format([node.name for node in graph_def.node if node.op=='Placeholder'])
    print ""
    print 'Unused Nodes: {}'.format([node.name for node in graph_def.node if 'unused'  in node.name])
    print ""
    print 'Output Nodes: {}'.format( [node.name for node in graph_def.node if 'predictions' in node.name])
    print ""
    print 'Quantization Nodes: {}'.format( [node.name for node in graph_def.node if 'quant' in node.name])
    print ""
    print 'Constant Count: {}'.format( len([node for node in graph_def.node if node.op=='Const']))
    print ""
    print 'Variable Count: {}'.format( len([node for node in graph_def.node if 'Variable' in node.op]))
    print ""
    print 'Identity Count: {}'.format( len([node for node in graph_def.node if node.op=='Identity']))
    print ""
    print 'Total nodes: {}'.format( len(graph_def.node))
    print ''
    
    if show_nodes:
        for node in graph_def.node:
            print 'Op:{} - Name: {}'.format(node.op, node.name)

4. Describe the SavedModel Graph (before optimisation)

Load GraphDef from a SavedModel Directory


In [16]:
def get_graph_def_from_saved_model(saved_model_dir):
    
    print saved_model_dir
    print ""
    
    from tensorflow.python.saved_model import tag_constants
    
    with tf.Session() as session:
        meta_graph_def = tf.saved_model.loader.load(
            session,
            tags=[tag_constants.SERVING],
            export_dir=saved_model_dir
        )
        
    return meta_graph_def.graph_def

In [17]:
describe_graph(get_graph_def_from_saved_model(saved_model_dir))


models/mnist/cnn_classifier/export/1538945645

Input Feature Nodes: [u'input_image']

Unused Nodes: [u'unused/kernel/Initializer/random_uniform/shape', u'unused/kernel/Initializer/random_uniform/min', u'unused/kernel/Initializer/random_uniform/max', u'unused/kernel/Initializer/random_uniform/RandomUniform', u'unused/kernel/Initializer/random_uniform/sub', u'unused/kernel/Initializer/random_uniform/mul', u'unused/kernel/Initializer/random_uniform', u'unused/kernel', u'unused/kernel/IsInitialized/VarIsInitializedOp', u'unused/kernel/Assign', u'unused/kernel/Read/ReadVariableOp', u'unused/bias/Initializer/zeros', u'unused/bias', u'unused/bias/IsInitialized/VarIsInitializedOp', u'unused/bias/Assign', u'unused/bias/Read/ReadVariableOp', u'unused/MatMul/ReadVariableOp', u'unused/MatMul', u'unused/BiasAdd/ReadVariableOp', u'unused/BiasAdd']

Output Nodes: [u'head/predictions/class_ids/dimension', u'head/predictions/class_ids', u'head/predictions/ExpandDims/dim', u'head/predictions/ExpandDims', u'head/predictions/str_classes', u'head/predictions/probabilities']

Quantization Nodes: []

Constant Count: 76

Variable Count: 105

Identity Count: 31

Total nodes: 351

Get model size


In [18]:
def get_size(model_dir):
    
    print model_dir
    print ""
    
    pb_size = os.path.getsize(os.path.join(model_dir,'saved_model.pb'))
    
    variables_size = 0
    if os.path.exists(os.path.join(model_dir,'variables/variables.data-00000-of-00001')):
        variables_size = os.path.getsize(os.path.join(model_dir,'variables/variables.data-00000-of-00001'))
        variables_size += os.path.getsize(os.path.join(model_dir,'variables/variables.index'))

    print "Model size: {} KB".format(round(pb_size/(1024.0),3))
    print "Variables size: {} KB".format(round( variables_size/(1024.0),3))
    print "Total Size: {} KB".format(round((pb_size + variables_size)/(1024.0),3))

In [19]:
get_size(saved_model_dir)


models/mnist/cnn_classifier/export/1538945645

Model size: 64.605 KB
Variables size: 12292.438 KB
Total Size: 12357.043 KB

5. Freeze SavedModel

This function converts the SavedModel into a frozen GraphDef file (freezed_model.pb), storing the variables as constants inside freezed_model.pb.

You need to specify the graph output nodes for freezing. We are only interested in the class_ids output, which is produced by the head/predictions/ExpandDims node.
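Conceptually, freezing does two things. It replaces each variable with a constant holding its checkpoint value, and it keeps only the nodes that the requested output nodes depend on. The pure-Python sketch models a graph as a hypothetical dict mapping node names to {"op", "inputs"} to illustrate both steps; it is not TensorFlow's implementation. The second step is consistent with the Unused Nodes list becoming empty after freezing in section 6.

```python
def freeze(graph, checkpoint, outputs):
    """Toy freezing: variables become constants; unneeded nodes are dropped.

    graph maps node name -> {"op": str, "inputs": [names]}; checkpoint maps
    variable name -> value. Both are hypothetical stand-ins.
    """
    # 1. Walk backwards from the outputs to find every node they depend on.
    needed, stack = set(), list(outputs)
    while stack:
        name = stack.pop()
        if name not in needed:
            needed.add(name)
            stack.extend(graph[name]["inputs"])
    # 2. Copy the needed nodes, replacing each variable with a constant.
    frozen = {}
    for name in needed:
        node = graph[name]
        if node["op"] == "Variable":
            frozen[name] = {"op": "Const", "inputs": [], "value": checkpoint[name]}
        else:
            frozen[name] = dict(node)
    return frozen


graph = {
    "input_image": {"op": "Placeholder", "inputs": []},
    "dense/kernel": {"op": "Variable", "inputs": []},
    "dense/MatMul": {"op": "MatMul", "inputs": ["input_image", "dense/kernel"]},
    "unused/kernel": {"op": "Variable", "inputs": []},
    "unused/MatMul": {"op": "MatMul", "inputs": ["input_image", "unused/kernel"]},
}
checkpoint = {"dense/kernel": [0.1, 0.2], "unused/kernel": [0.3]}
frozen = freeze(graph, checkpoint, ["dense/MatMul"])
print(sorted(frozen))
```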


In [20]:
def freeze_graph(saved_model_dir):
    
    from tensorflow.python.tools import freeze_graph
    from tensorflow.python.saved_model import tag_constants
    
    output_graph_filename = os.path.join(saved_model_dir, "freezed_model.pb")
    output_node_names = "head/predictions/ExpandDims"
    initializer_nodes = ""

    freeze_graph.freeze_graph(
        input_saved_model_dir=saved_model_dir,
        output_graph=output_graph_filename,
        saved_model_tags = tag_constants.SERVING,
        output_node_names=output_node_names,
        initializer_nodes=initializer_nodes,

        input_graph=None, 
        input_saver=False,
        input_binary=False, 
        input_checkpoint=None, 
        restore_op_name=None, 
        filename_tensor_name=None, 
        clear_devices=False,
        input_meta_graph=False,
    )
    
    print "SavedModel graph freezed!"

In [21]:
freeze_graph(saved_model_dir)


SavedModel graph freezed!

In [22]:
%%bash
saved_models_base=models/mnist/cnn_classifier/export/
saved_model_dir=${saved_models_base}$(ls ${saved_models_base} | tail -n 1)
echo ${saved_model_dir}
ls ${saved_model_dir}


models/mnist/cnn_classifier/export/1538945645
freezed_model.pb
saved_model.pb
variables

6. Describe the freezed_model.pb Graph (after freezing)

Load GraphDef from GraphDef File


In [23]:
def get_graph_def_from_file(graph_filepath):
    
    print graph_filepath
    print ""
    
    # parse the binary GraphDef protobuf; no graph context is needed for parsing
    with tf.gfile.GFile(graph_filepath, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    
    return graph_def

In [24]:
freezed_filepath=os.path.join(saved_model_dir,'freezed_model.pb')
describe_graph(get_graph_def_from_file(freezed_filepath))


models/mnist/cnn_classifier/export/1538945645/freezed_model.pb

Input Feature Nodes: [u'input_image']

Unused Nodes: []

Output Nodes: [u'head/predictions/class_ids/dimension', u'head/predictions/class_ids', u'head/predictions/ExpandDims/dim', u'head/predictions/ExpandDims']

Quantization Nodes: []

Constant Count: 36

Variable Count: 0

Identity Count: 26

Total nodes: 93

7. Optimise the freezed_model.pb

Optimise GraphDef


In [25]:
def optimize_graph(model_dir, graph_filename, transforms):
    
    from tensorflow.tools.graph_transforms import TransformGraph
    
    input_names = []
    output_names = ['head/predictions/ExpandDims']
    
    graph_def = get_graph_def_from_file(os.path.join(model_dir, graph_filename))
    optimised_graph_def = TransformGraph(graph_def, 
                                         input_names,
                                         output_names,
                                         transforms 
                                        )
    tf.train.write_graph(optimised_graph_def,
                        logdir=model_dir,
                        as_text=False,
                        name='optimised_model.pb')
    
    print "Freezed graph optimised!"

In [26]:
transforms = [
    'remove_nodes(op=Identity)', 
    'fold_constants(ignore_errors=true)',
    'fold_batch_norms',
#    'fuse_resize_pad_and_conv',
#    'quantize_weights',
#    'quantize_nodes',
    'merge_duplicate_nodes',
    'strip_unused_nodes', 
    'sort_by_execution_order'
]

optimize_graph(saved_model_dir, 'freezed_model.pb', transforms)


models/mnist/cnn_classifier/export/1538945645/freezed_model.pb

Freezed graph optimised!
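Batch-norm folding rests on simple arithmetic. At inference time, a batch normalisation applied directly to a linear layer's output can be merged into that layer's weights and bias. The sketch is a single-channel, pure-Python illustration with hypothetical names; eps=1e-3 matches the Keras BatchNormalization default. The fold_batch_norms transform looks for channel-wise multiplies immediately after a Conv2D or MatMul. In _cnn_layers, however, the BatchNormalization follows MaxPool2D rather than Conv2D, so the transform may find nothing to fold in this particular model. We have not checked the transformed graph to confirm this.

```python
import math


def fold_batch_norm(weights, bias, gamma, beta, mean, var, eps=1e-3):
    """Fold y = gamma * (w.x + b - mean) / sqrt(var + eps) + beta into w'.x + b'.

    Single output channel for clarity; per-channel folding applies the
    same formula to each output channel's weights and bias.
    """
    scale = gamma / math.sqrt(var + eps)
    folded_weights = [w * scale for w in weights]
    folded_bias = (bias - mean) * scale + beta
    return folded_weights, folded_bias


# Check that the folded layer matches linear layer + batch norm on one input.
w, b = [0.5, -1.0], 0.2
gamma, beta, mean, var = 1.5, -0.3, 0.1, 4.0
x = [2.0, 3.0]
linear = sum(wi * xi for wi, xi in zip(w, x)) + b
expected = gamma * (linear - mean) / math.sqrt(var + 1e-3) + beta
fw, fb = fold_batch_norm(w, b, gamma, beta, mean, var)
folded = sum(wi * xi for wi, xi in zip(fw, x)) + fb
print(abs(folded - expected) < 1e-9)
```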

In [27]:
%%bash
saved_models_base=models/mnist/cnn_classifier/export/
saved_model_dir=${saved_models_base}$(ls ${saved_models_base} | tail -n 1)
echo ${saved_model_dir}
ls ${saved_model_dir}


models/mnist/cnn_classifier/export/1538945645
freezed_model.pb
optimised_model.pb
saved_model.pb
variables

8. Describe the Optimised Graph


In [28]:
optimised_filepath=os.path.join(saved_model_dir,'optimised_model.pb')
describe_graph(get_graph_def_from_file(optimised_filepath))


models/mnist/cnn_classifier/export/1538945645/optimised_model.pb

Input Feature Nodes: [u'input_image']

Unused Nodes: []

Output Nodes: [u'head/predictions/class_ids', u'head/predictions/ExpandDims']

Quantization Nodes: []

Constant Count: 29

Variable Count: 0

Identity Count: 0

Total nodes: 60
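The Identity Count drops from 26 to 0 after remove_nodes(op=Identity). The idea is to delete each Identity node and point its consumers at the Identity's own input. The toy sketch represents the graph as a hypothetical dict mapping node names to {"op", "inputs"}, a stand-in for a GraphDef. It ignores real GraphDef details such as control-dependency inputs (^name) and output-port suffixes (name:1).

```python
def remove_identity_nodes(graph):
    """Drop Identity nodes and rewire consumers to the Identity's input."""

    def resolve(name):
        # Follow chains of Identity nodes back to the real producer.
        while graph[name]["op"] == "Identity":
            name = graph[name]["inputs"][0]
        return name

    return {
        name: {"op": node["op"], "inputs": [resolve(i) for i in node["inputs"]]}
        for name, node in graph.items()
        if node["op"] != "Identity"
    }


graph = {
    "x": {"op": "Placeholder", "inputs": []},
    "w": {"op": "Const", "inputs": []},
    "w/read": {"op": "Identity", "inputs": ["w"]},
    "w/read2": {"op": "Identity", "inputs": ["w/read"]},
    "mm": {"op": "MatMul", "inputs": ["x", "w/read2"]},
}
pruned = remove_identity_nodes(graph)
print(pruned["mm"]["inputs"])
```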

9. Convert Optimised graph (GraphDef) to SavedModel


In [29]:
def convert_graph_def_to_saved_model(graph_filepath):

    # write the SavedModel to an 'optimised' folder next to the GraphDef file,
    # rather than relying on the global saved_model_dir
    export_dir = os.path.join(os.path.dirname(graph_filepath), 'optimised')

    if tf.gfile.Exists(export_dir):
        tf.gfile.DeleteRecursively(export_dir)

    graph_def = get_graph_def_from_file(graph_filepath)
    
    with tf.Session(graph=tf.Graph()) as session:
        # import with an empty prefix so the node names stay unchanged
        tf.import_graph_def(graph_def, name="")
        tf.saved_model.simple_save(session,
                export_dir,
                # one input per Placeholder node, keyed by the node name
                inputs={
                    node.name: session.graph.get_tensor_by_name("{}:0".format(node.name)) 
                    for node in graph_def.node if node.op=='Placeholder'},
                outputs={
                    "class_ids": session.graph.get_tensor_by_name("head/predictions/ExpandDims:0"),
                }
            )

        print "Optimised graph converted to SavedModel!"

In [30]:
optimised_filepath=os.path.join(saved_model_dir,'optimised_model.pb')
convert_graph_def_to_saved_model(optimised_filepath)


models/mnist/cnn_classifier/export/1538945645/optimised_model.pb

Optimised graph converted to SavedModel!

Optimised SavedModel Size


In [31]:
optimised_saved_model_dir = os.path.join(saved_model_dir,'optimised') 
get_size(optimised_saved_model_dir)


models/mnist/cnn_classifier/export/1538945645/optimised

Model size: 10701.646 KB
Variables size: 0.0 KB
Total Size: 10701.646 KB

In [32]:
%%bash

saved_models_base=models/mnist/cnn_classifier/export/
saved_model_dir=${saved_models_base}$(ls ${saved_models_base} | tail -n 1)/optimised
ls ${saved_model_dir}
saved_model_cli show --dir ${saved_model_dir} --all


saved_model.pb
variables

MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['serving_default']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['input_image'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 784)
        name: input_image:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['class_ids'] tensor_info:
        dtype: DT_INT64
        shape: (-1, 1)
        name: head/predictions/ExpandDims:0
  Method name is: tensorflow/serving/predict

10. Prediction with the Optimised SavedModel


In [33]:
optimised_saved_model_dir = os.path.join(saved_model_dir,'optimised') 
print(optimised_saved_model_dir)
inference_test(saved_model_dir=optimised_saved_model_dir, signature='serving_default', input_name='input_image')


models/mnist/cnn_classifier/export/1538945645/optimised

Model loading time: 0.13307 seconds

Inference elapsed time: 31.750448 seconds

Prediction produced for 300 instances batch, repeated 100 times
Average latency per batch: 0.31750448 seconds

Prediction output for the last instance:
class_ids: [7]

Cloud ML Engine Deployment and Prediction


In [ ]:
PROJECT = 'ksalama-gcp-playground'
BUCKET = 'ksalama-gcs-cloudml'
REGION = 'europe-west1'
MODEL_NAME = 'mnist_classifier'

os.environ['BUCKET'] = BUCKET
os.environ['PROJECT'] = PROJECT
os.environ['REGION'] = REGION
os.environ['MODEL_NAME'] = MODEL_NAME

1. Upload the model artefacts to a Google Cloud Storage bucket


In [ ]:
%%bash

gsutil -m rm -r gs://${BUCKET}/tf-model-optimisation

In [ ]:
%%bash

saved_models_base=models/mnist/cnn_classifier/export/
saved_model_dir=${saved_models_base}$(ls ${saved_models_base} | tail -n 1)

echo ${saved_model_dir}

gsutil -m cp -r ${saved_model_dir} gs://${BUCKET}/tf-model-optimisation/original

In [ ]:
%%bash

saved_models_base=models/mnist/cnn_classifier/export/
saved_model_dir=${saved_models_base}$(ls ${saved_models_base} | tail -n 1)/optimised

echo ${saved_model_dir}

gsutil -m cp -r ${saved_model_dir} gs://${BUCKET}/tf-model-optimisation

2. Deploy models to Cloud ML Engine

Don't forget to delete the model and the model version if they were previously deployed!


In [ ]:
%%bash

echo ${MODEL_NAME}

gcloud ml-engine models create ${MODEL_NAME} --regions=${REGION}

Version: v_org is the original SavedModel (before optimisation)


In [ ]:
%%bash

MODEL_VERSION='v_org'
MODEL_ORIGIN=gs://${BUCKET}/tf-model-optimisation/original

gcloud ml-engine versions create ${MODEL_VERSION} \
            --model=${MODEL_NAME} \
            --origin=${MODEL_ORIGIN} \
            --runtime-version=1.10

Version: v_opt is the optimised SavedModel (after optimisation)


In [ ]:
%%bash

MODEL_VERSION='v_opt'
MODEL_ORIGIN=gs://${BUCKET}/tf-model-optimisation/optimised

gcloud ml-engine versions create ${MODEL_VERSION} \
            --model=${MODEL_NAME} \
            --origin=${MODEL_ORIGIN} \
            --runtime-version=1.10

3. Cloud ML Engine online predictions


In [ ]:
from googleapiclient import discovery
from oauth2client.client import GoogleCredentials

credentials = GoogleCredentials.get_application_default()
api = discovery.build(
    'ml', 'v1', 
    credentials=credentials, 
    discoveryServiceUrl='https://storage.googleapis.com/cloud-ml/discovery/ml_v1_discovery.json'
)

    
def predict(version, instances):

    request_data = {'instances': instances}

    model_url = 'projects/{}/models/{}/versions/{}'.format(PROJECT, MODEL_NAME, version)
    response = api.projects().predict(body=request_data, name=model_url).execute()

    class_ids = None
    
    try:
        class_ids = [item["class_ids"] for item in response["predictions"]]
    except KeyError:
        # the response has no predictions (e.g. it carries an error message)
        print response
    
    return class_ids
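predict expects request_data['instances'] to be a list with one JSON object per instance, keyed by the SavedModel input name. A small hypothetical helper can build such a body from image arrays. It converts the values to plain Python floats first, because the standard json module cannot serialise numpy float32 values.

```python
import json


def build_predict_body(images, input_name="input_image"):
    """Build an online-prediction request body: {"instances": [{...}, ...]}.

    Values are converted to plain Python floats so the body is
    JSON-serialisable even when the images are numpy arrays.
    """
    return {"instances": [{input_name: [float(v) for v in image]}
                          for image in images]}


body = build_predict_body([[0.0, 0.5, 1.0]])
payload = json.dumps(body)
print(len(payload))
```

In this notebook it could be called as build_predict_body(eval_data[:100])['instances'] to produce the same list that inference_cmle builds by hand.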

In [ ]:
def inference_cmle(version, batch=100, repeat=10):
    
    instances = [
            {'input_image': [float(i) for i in list(eval_data[img])] }
        for img in range(batch)
    ]

    # warm-up request ('instances' must be a list, so send a one-element list)
    predict(version, instances[:1])
    print 'Warm up request performed!'
    print 'Timer started...'
    print ''
    
    time_start = datetime.utcnow() 
    output = None
    
    for i in range(repeat):
        output = predict(version, instances)
    
    time_end = datetime.utcnow() 

    time_elapsed_sec = (time_end - time_start).total_seconds()
    
    print "Inference elapsed time: {} seconds".format(time_elapsed_sec)
    print ""
    
    print "Prediction produced for {} instances batch, repeated {} times".format(len(output), repeat)
    print "Average latency per batch: {} seconds".format(time_elapsed_sec/repeat)
    print ""
    
    print "Prediction output for the first instance: {}".format(output[0])

In [ ]:
version='v_org'
inference_cmle(version)

In [ ]:
version='v_opt'
inference_cmle(version)

Happy serving!

